{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ac786fd8-74ed-4c86-844e-2468690a3ccb",
   "metadata": {},
   "source": [
    "## 一、安装所需的库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2160de36-4297-462c-90fb-636595875981",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://repo.huaweicloud.com/repository/pypi/simple/\n",
      "Requirement already satisfied: mindnlp in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (0.4.1)\n",
      "Requirement already satisfied: mindspore>=2.2.14 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.5.0)\n",
      "Requirement already satisfied: tqdm in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (4.67.1)\n",
      "Requirement already satisfied: requests in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.32.3)\n",
      "Requirement already satisfied: datasets in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (3.6.0)\n",
      "Requirement already satisfied: evaluate in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.4.3)\n",
      "Requirement already satisfied: tokenizers==0.19.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.19.1)\n",
      "Requirement already satisfied: safetensors in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.5.3)\n",
      "Requirement already satisfied: sentencepiece in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.2.0)\n",
      "Requirement already satisfied: regex in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2024.11.6)\n",
      "Requirement already satisfied: addict in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.4.0)\n",
      "Requirement already satisfied: ml-dtypes in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.5.1)\n",
      "Requirement already satisfied: pyctcdecode in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (0.5.0)\n",
      "Requirement already satisfied: pytest==7.2.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (7.2.0)\n",
      "Requirement already satisfied: pillow>=10.0.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (11.2.1)\n",
      "Requirement already satisfied: attrs>=19.2.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (25.3.0)\n",
      "Requirement already satisfied: iniconfig in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.1.0)\n",
      "Requirement already satisfied: packaging in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (24.2)\n",
      "Requirement already satisfied: pluggy<2.0,>=0.12 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.5.0)\n",
      "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.2.0)\n",
      "Requirement already satisfied: tomli>=1.0.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.2.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from tokenizers==0.19.1->mindnlp) (0.31.4)\n",
      "Requirement already satisfied: filelock in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers==0.19.1->mindnlp) (3.18.0)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers==0.19.1->mindnlp) (2025.3.0)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers==0.19.1->mindnlp) (6.0.2)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers==0.19.1->mindnlp) (4.12.2)\n",
      "Requirement already satisfied: numpy<2.0.0,>=1.20.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (1.26.4)\n",
      "Requirement already satisfied: protobuf>=3.13.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (6.30.2)\n",
      "Requirement already satisfied: asttokens>=2.0.4 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (3.0.0)\n",
      "Requirement already satisfied: scipy>=1.5.4 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (1.13.1)\n",
      "Requirement already satisfied: psutil>=5.6.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (5.9.0)\n",
      "Requirement already satisfied: astunparse>=1.6.3 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (1.6.3)\n",
      "Requirement already satisfied: dill>=0.3.7 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (0.3.8)\n",
      "Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore>=2.2.14->mindnlp) (0.45.1)\n",
      "Requirement already satisfied: six<2.0,>=1.6.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore>=2.2.14->mindnlp) (1.17.0)\n",
      "Requirement already satisfied: pyarrow>=15.0.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (20.0.0)\n",
      "Requirement already satisfied: pandas in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (2.2.3)\n",
      "Requirement already satisfied: xxhash in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (3.5.0)\n",
      "Requirement already satisfied: multiprocess<0.70.17 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.70.16)\n",
      "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets->mindnlp) (3.11.18)\n",
      "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets->mindnlp) (2.6.1)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets->mindnlp) (1.3.2)\n",
      "Requirement already satisfied: async-timeout<6.0,>=4.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets->mindnlp) (5.0.1)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets->mindnlp) (1.6.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets->mindnlp) (6.4.4)\n",
      "Requirement already satisfied: propcache>=0.2.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets->mindnlp) (0.3.1)\n",
      "Requirement already satisfied: yarl<2.0,>=1.17.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets->mindnlp) (1.20.0)\n",
      "Requirement already satisfied: idna>=2.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2025.3.0,>=2023.1.0->datasets->mindnlp) (3.10)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (3.4.1)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2.4.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2025.4.26)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2025.2)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2025.2)\n",
      "Requirement already satisfied: pygtrie<3.0,>=2.1 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pyctcdecode->mindnlp) (2.5.0)\n",
      "Requirement already satisfied: hypothesis<7,>=6.14 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from pyctcdecode->mindnlp) (6.131.21)\n",
      "Requirement already satisfied: sortedcontainers<3.0.0,>=2.1.0 in /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages (from hypothesis<7,>=6.14->pyctcdecode->mindnlp) (2.4.0)\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m25.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.1.1\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Use %pip (not !pip) so the package is installed into the environment of the running kernel,\n",
    "# and pin the release this notebook was executed with: later releases may not provide the\n",
    "# mindnlp.dataset / mindnlp.engine modules used below.\n",
    "%pip install mindnlp==0.4.1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79c97c5a-bb98-4cdc-b5e9-4f1828192b21",
   "metadata": {},
   "source": [
    "## 二、加载数据"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14adb23e-73aa-4f70-94e7-78cc8104700b",
   "metadata": {},
   "source": [
    "首先，导入了必要的库，包括 mindnlp.dataset 中的 load_dataset 和用于保证可复现性的 random。接着加载了 IMDB 数据集，并将其划分为训练集和测试集。最后，代码将原始训练集进一步划分为一个新的、较小的训练集（占70%）和一个验证集（占30%）。\n",
    "这里提到的 IMDB 数据集是一个广泛应用于情感分析或文本分类任务的流行数据集。它包含了从互联网电影数据库 (IMDB) 提取的 50,000 条电影评论，平均分配为 25,000 条用于训练，25,000 条用于测试。每条评论都被标记为正面或负面，这使其成为一个理想的二元数据集，适用于训练能够理解和分类文本中所表达观点的情感分类模型。它在自然语言处理社区被广泛用于比较情感分类模型的性能。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b70480c0-4fa3-41a0-94b4-3a0d5fbf1b0a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(\"'Stanley and Iris' show the triumph of the human spirit. For Stanley, it's the struggle to become literate and realize his potential. For Iris, it's to find the courage to love again after becoming a widow. The beauty of the movie is the dance that Robert DeNiro and Jane Fonda do together, starting and stopping, before each has the skills and courage to completely trust each other and move on. In that sense it very nicely gives us a good view of how life often is, thus being credible. Unlike some other reviewers I found the characters each rendered to be consistent for the whole picture. The supporting cast is also carefully chosen and they add a depth of character that the main characters get added meaning from the supporting performances. All in all an excellent movie. The best thing I take from it is Hope.\", 1)\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.dataset import load_dataset\n",
    "import random\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "# The random train/validation split below is performed by MindSpore's dataset pipeline,\n",
    "# so seed that pipeline as well to make the split repeatable across runs.\n",
    "ds.config.set_seed(SEED)\n",
    "\n",
    "# Load the complete IMDB dataset (train and test splits)\n",
    "imdb_ds = load_dataset('imdb', split=['train', 'test'])\n",
    "imdb_train = imdb_ds['train']\n",
    "imdb_test = imdb_ds['test']\n",
    "\n",
    "# Hold out 30% of the original training split as a validation set\n",
    "imdb_train, imdb_val = imdb_train.split([0.7, 0.3])\n",
    "print(imdb_train[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58456c9b-d523-40eb-9bfa-e7c8d2934af9",
   "metadata": {},
   "source": [
    "## 三、数据预处理"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9fc1eea-faa0-443e-8d4d-2ed218c71941",
   "metadata": {},
   "source": [
    "首先，它定义了一个内部的 tokenize 函数，该函数利用传入的 tokenizer 将文本数据转换为模型所需的 input_ids（词元ID）、token_type_ids（片段ID）和 attention_mask（注意力掩码），同时根据 max_seq_len 控制序列长度。然后，函数会根据 shuffle 参数决定是否打乱数据集顺序，并根据 take_len 参数选择是否只取数据集的一部分。接着，它将 tokenize 函数应用于数据集的 \"text\" 列，生成新的特征列。随后，它将标签列 \"label\" 的数据类型转换为整型并重命名为 \"labels\"。最后，也是非常重要的一步，函数使用 padded_batch 方法将数据集分批，并对批次内长度不一的序列进行填充（padding），使得每个批次中的所有样本在 input_ids、token_type_ids 和 attention_mask 这些维度上都具有相同的长度，填充值由 tokenizer.pad_token_id 和0指定。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8b890e36-1f4f-4af0-9129-82b27e4a6cbf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import mindspore\n",
    "import numpy as np\n",
    "from mindspore.dataset import transforms\n",
    "\n",
    "def process_dataset(dataset, tokenizer, max_seq_len=256, batch_size=32, shuffle=False, take_len=None):\n",
    "    \"\"\"Tokenize a text-classification dataset and group it into padded batches.\n",
    "\n",
    "    Args:\n",
    "        dataset: MindSpore dataset exposing a 'text' column and a 'label' column.\n",
    "        tokenizer: Callable tokenizer whose result provides 'input_ids', 'token_type_ids'\n",
    "            and 'attention_mask'; its `pad_token_id` is used to pad 'input_ids'.\n",
    "        max_seq_len: Maximum sequence length passed to the tokenizer; longer inputs are truncated.\n",
    "        batch_size: Number of samples per batch.\n",
    "        shuffle: If True, shuffle the whole dataset before sub-sampling and batching.\n",
    "        take_len: If truthy, keep only this many samples (taken after the optional shuffle).\n",
    "\n",
    "    Returns:\n",
    "        The dataset returned by `padded_batch`, after 'text' has been tokenized and 'label'\n",
    "        has been cast to int32 and renamed to 'labels'.\n",
    "    \"\"\"\n",
    "    def tokenize(text):\n",
    "        tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)\n",
    "        return tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']\n",
    "\n",
    "    if shuffle:\n",
    "        # Use a buffer spanning the whole dataset (global shuffle). A buffer of only\n",
    "        # `batch_size` rows barely reorders the data, so a following take() would return\n",
    "        # little more than the leading rows -- a single class when the source is sorted by label.\n",
    "        # shuffle() also requires buffer_size > 1, hence the lower bound of 2.\n",
    "        dataset = dataset.shuffle(buffer_size=max(dataset.get_dataset_size(), 2))\n",
    "\n",
    "    # Optionally keep only the first `take_len` entries\n",
    "    if take_len:\n",
    "        dataset = dataset.take(take_len)\n",
    "\n",
    "    # Replace the 'text' column with the three columns produced by the tokenizer\n",
    "    dataset = dataset.map(operations=[tokenize], input_columns=\"text\", output_columns=['input_ids', 'token_type_ids', 'attention_mask'])\n",
    "    # Cast 'label' to int32 and rename the column to 'labels'\n",
    "    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns=\"label\", output_columns=\"labels\")\n",
    "    # Batch with padding so every sequence in a batch has the same length\n",
    "    dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),\n",
    "                                                         'token_type_ids': (None, 0),\n",
    "                                                         'attention_mask': (None, 0)})\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "541a5a62-bbd7-40dd-864d-84200d0a5c3b",
   "metadata": {},
   "source": [
    "首先从 mindnlp.transformers 库导入并加载了 \"google/canine-s\" 的预训练 CanineTokenizer 分词器实例。然后，定义了批处理大小 batch_size 为 4，并计算了一个 take_len 变量（batch_size × 25，即 100），用于从完整训练数据集中选取一个子集。接着，代码调用了先前定义的 process_dataset 函数，分别对 imdb_train 和 imdb_test 数据集进行预处理：对于训练集，它创建了一个名为 small_dataset_train 的小型、打乱顺序并按批次处理的数据集，其中包含了 take_len（即 100）条样本；对于测试集，创建了 small_dataset_test，但只选取了 50 条样本进行打乱和批处理。这两个经过处理的小型数据集随后可以用于模型的快速训练和评估。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d4413b3f-e261-403e-b5e3-6bda9f7a775b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import CanineTokenizer\n",
    "# Tokenizer shipped with the pretrained google/canine-s checkpoint\n",
    "tokenizer = CanineTokenizer.from_pretrained(\"google/canine-s\")\n",
    "\n",
    "batch_size = 4\n",
    "take_len = batch_size * 25  # number of training samples kept (25 batches)\n",
    "\n",
    "# Small shuffled subsets keep this demo quick to train and evaluate\n",
    "small_dataset_train = process_dataset(\n",
    "    imdb_train,\n",
    "    tokenizer,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    take_len=take_len,\n",
    ")\n",
    "small_dataset_test = process_dataset(\n",
    "    imdb_test,\n",
    "    tokenizer,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    take_len=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c004c9b-a154-4146-9f7c-62fecb688cc8",
   "metadata": {},
   "source": [
    "## 四、模型训练"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bde5b80-f33a-44e9-bdb5-91440a1a883a",
   "metadata": {},
   "source": [
    "首先从 mindnlp.transformers 库中导入了 CanineConfig 类和 CanineForSequenceClassification 类，这两个类分别用于定义CANINE模型的配置和构建序列分类模型。接着，代码创建了一个 CanineConfig 的实例，并详细设置了模型的各项超参数，如隐藏层大小、注意力头数量、激活函数、dropout概率以及针对当前二分类任务的标签数量（num_labels=2）等。最后，使用这个配置好的 config 对象实例化了一个 CanineForSequenceClassification 模型，这个模型将基于CANINE架构并适用于序列分类任务，并通过 print(type(model)) 打印出所创建模型的类型，以确认模型已成功创建。需要注意的是，该模型仅由配置构建，参数为随机初始化，并未加载预训练权重。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56b00527-8035-4f18-b124-31a79af8e376",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] DEVICE(10598,fffebaffe120,python):2025-05-22-12:07:48.245.533 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_vmm_adapter.h:147] CheckVmmDriverVersion] Open file /etc/ascend_install.info failed.\n",
      "[WARNING] DEVICE(10598,fffebaffe120,python):2025-05-22-12:07:48.245.625 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_vmm_adapter.h:186] CheckVmmDriverVersion] Driver version is less than 24.0.0, vmm is disabled by default, drvier_version: 23.0.6\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'mindnlp.transformers.models.canine.modeling_canine.CanineForSequenceClassification'>\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import CanineConfig, CanineForSequenceClassification\n",
    "\n",
    "# Configuration of a very small CANINE encoder with a two-class classification head\n",
    "config = CanineConfig(\n",
    "    # encoder dimensions\n",
    "    hidden_size=32,\n",
    "    num_hidden_layers=2,\n",
    "    num_attention_heads=4,\n",
    "    intermediate_size=37,\n",
    "    hidden_act=\"gelu\",\n",
    "    # dropout\n",
    "    hidden_dropout_prob=0.1,\n",
    "    attention_probs_dropout_prob=0.1,\n",
    "    # embeddings\n",
    "    max_position_embeddings=512,\n",
    "    type_vocab_size=16,\n",
    "    # number of target classes for the classification task\n",
    "    num_labels=2,\n",
    ")\n",
    "# The model is built from the config alone; no pretrained checkpoint is loaded here\n",
    "model = CanineForSequenceClassification(config)\n",
    "\n",
    "print(type(model))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b25acaf-7d97-4b0f-b5ef-f2bf497c2447",
   "metadata": {},
   "source": [
    "使用 mindnlp.engine 中的 TrainingArguments 类来配置模型训练过程中的各种参数。首先指定了训练过程中的输出目录为当前路径下的 \"output\" 文件夹，用于保存模型、检查点和日志等。接着，设置了每个设备上的训练批次大小和评估批次大小均为1，学习率为2e-5，总训练轮次为1轮。此外，还配置了日志记录的频率（每50步记录一次），并设定评估策略和保存策略均为在每个训练轮次（epoch）结束时执行。这些参数共同定义了模型训练的行为和流程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9ef76c85-7fa4-4883-b224-e336a670dd50",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from mindnlp.engine import TrainingArguments\n",
    "# NOTE(review): the datasets are already batched by process_dataset, so confirm whether the\n",
    "# per-device batch sizes below have any effect on the actual batch size.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./output\",  # destination directory handed to the trainer\n",
    "    learning_rate=2e-5,\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    "    logging_steps=50,\n",
    "    # evaluate and save once per epoch\n",
    "    evaluation_strategy='epoch',\n",
    "    save_strategy='epoch',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "873fd8d7-123f-4933-8437-4011348641af",
   "metadata": {},
   "source": [
    "代码定义了一个用于在模型评估过程中计算指标（具体来说是准确率）的函数 compute_metrics。首先，它从 evaluate 库加载了 \"accuracy\"（准确率）评估指标，并从 mindnlp.engine.utils 导入了 EvalPrediction 类型，该类型用于封装模型的预测输出和真实标签。然后，compute_metrics 函数接收一个 EvalPrediction 对象（包含模型的原始输出 logits 和真实 labels），通过 np.argmax 函数找到 logits 中概率最高的类别作为预测结果，最后利用加载的 metric 对象计算并返回预测结果与真实标签之间的准确率。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e30ddb7a-1b5f-4db6-8f49-eaa93a800ab2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import numpy as np\n",
    "from mindnlp.engine.utils import EvalPrediction\n",
    "\n",
    "# Accuracy metric from the `evaluate` library, loaded once and reused by compute_metrics\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred: EvalPrediction):\n",
    "    \"\"\"Return the accuracy of the arg-max class predictions against the reference labels.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: Pair that unpacks into (logits, labels).\n",
    "\n",
    "    Returns:\n",
    "        The result of `metric.compute` for the predicted classes and the labels.\n",
    "    \"\"\"\n",
    "    raw_scores, references = eval_pred\n",
    "    predicted_classes = np.argmax(raw_scores, axis=-1)\n",
    "    return metric.compute(references=references, predictions=predicted_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46782c29-cf8f-4a54-aa15-9f8d3e290f8d",
   "metadata": {},
   "source": [
    "代码初始化了一个 mindnlp.engine 中的 Trainer 对象。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a067e895-0664-4b21-9374-128bca31e07b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from mindnlp.engine import Trainer\n",
    "# Wire the model, training configuration, datasets and metric function together\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=small_dataset_train,\n",
    "    eval_dataset=small_dataset_test,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bfa95a65-8bf8-41af-a17d-73997bc8e993",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "13d754d548eb452c987cf4e31a1715da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "......"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/13 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.6978093981742859, 'eval_accuracy': 0.0, 'eval_runtime': 8.6984, 'eval_samples_per_second': 1.495, 'eval_steps_per_second': 1.495, 'epoch': 1.0}\n",
      "{'train_runtime': 89.1989, 'train_samples_per_second': 1.121, 'train_steps_per_second': 0.28, 'train_loss': 0.694559326171875, 'epoch': 1.0}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=25, training_loss=0.694559326171875, metrics={'train_runtime': 89.1989, 'train_samples_per_second': 1.121, 'train_steps_per_second': 0.28, 'train_loss': 0.694559326171875, 'epoch': 1.0})"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f133b647-9c46-48d1-8dd8-55fd868fa79d",
   "metadata": {},
   "source": [
    "## 五、模型评估"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9727e74f-4e9f-45a0-8cd2-5b2dc04a732e",
   "metadata": {},
   "source": [
    "使用训练好的CANINE模型对新的单条文本 \"I absolutely love this movie\" 进行情感预测。它首先利用之前加载的 tokenizer 将文本转换为模型所需的数值化输入，并确保其长度符合要求且转换为MindSpore的 Tensor 格式，同时增加了一个批次维度。随后，代码将模型设置为评估模式，并将处理好的输入传递给模型以获取预测结果。运行结果 SequenceClassifierOutput(loss=None, logits=Tensor(shape=[1, 2], dtype=Float32, value=[[-1.98532678e-02, -1.04218852e-02]]), hidden_states=None, attentions=None) 表明模型成功输出了 logits，这是一个包含两个值的张量，分别代表模型预测该文本属于两个类别（索引 0 为负面，索引 1 为正面）的原始分数；在这个具体的例子中，第二个值略高于第一个值，暗示模型倾向于将该文本归类为索引 1（正面）对应的类别，不过两者的差距非常小。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "557ec336-720d-4834-9c67-b0c304bce143",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SequenceClassifierOutput(loss=None, logits=Tensor(shape=[1, 2], dtype=Float32, value=\n",
      "[[-1.98532678e-02, -1.04218852e-02]]), hidden_states=None, attentions=None)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from mindspore import Tensor, ops\n",
    "\n",
    "text = \"I absolutely love this movie\"\n",
    "\n",
    "# Encode the sentence, then wrap each field in a Tensor with a leading batch axis of size 1\n",
    "inputs = tokenizer(text, padding=True, truncation=True, max_length=256)\n",
    "ts_inputs = dict(\n",
    "    (field_name, Tensor(field_values).expand_dims(0))\n",
    "    for field_name, field_values in inputs.items()\n",
    ")\n",
    "\n",
    "# Turn training mode off before running the forward pass\n",
    "model.set_train(False)\n",
    "outputs = model(**ts_inputs)\n",
    "print(outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9b8e650-2118-4936-a2c0-71d5544a65f2",
   "metadata": {},
   "source": [
    "首先通过 ops.softmax 函数将模型输出的 logits 转换为概率分布，然后从中提取每个类别的概率值，并使用 argmax() 确定概率最高的类别索引。接着，利用一个预设的 labels 字典（0代表'neg'，1代表'pos'）将此索引转换为对应的文本标签。最后，代码打印出计算得到的负面情感概率（0.4976）和正面情感概率（0.5024），以及最终预测的类别为 'pos'（对应索引1）。虽然这一预测与输入文本 \"I absolutely love this movie\" 的正面情感一致，但两个类别的概率都非常接近 0.5，且上文评估得到的 eval_accuracy 为 0.0，说明这个随机初始化、仅用 100 条样本训练 1 轮的小模型尚未真正学到情感特征，该结果接近随机猜测，仍需要使用更多数据和训练轮次进一步优化或分析。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1a93e2e3-7f95-4cd6-b541-9a68b7c2d402",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Negative sentiment: 0.4976\n",
      "Positive sentiment: 0.5024\n",
      "Predicted class: pos\n",
      "Predicted label: 1\n"
     ]
    }
   ],
   "source": [
    "# Softmax turns the logits into class probabilities; flatten() yields a 1-D array\n",
    "predictions = ops.softmax(outputs.logits)\n",
    "probabilities = predictions.numpy().flatten()\n",
    "\n",
    "# Class index -> readable label\n",
    "labels = {0: 'neg', 1: 'pos'}\n",
    "\n",
    "# The most probable class and its label\n",
    "pred_class_idx = probabilities.argmax()\n",
    "pred_class_label = labels[pred_class_idx]\n",
    "\n",
    "# Report both probabilities followed by the final decision\n",
    "print(f\"Negative sentiment: {probabilities[0]:.4f}\")\n",
    "print(f\"Positive sentiment: {probabilities[1]:.4f}\")\n",
    "print(f\"Predicted class: {pred_class_label}\")\n",
    "print(f\"Predicted label: {pred_class_idx}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
